import pandas as p
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from collections import OrderedDict

#Read the Goodman-Bacon decomposition table and attach group sizes
def readData():
    """Load the decomposition CSV and return it with group-size columns.

    Reads 'Data/baconData2.csv' (relative to the working directory) and
    passes the resulting frame through appendGroupSizes before returning it.
    """
    return appendGroupSizes(p.read_csv('Data/baconData2.csv'))


def getNumWithinGroup(data, treatment, control, stub, includeCO = True):
    """Count the distinct states that belong to the treatment or control group.

    Parameters
    ----------
    data : DataFrame with a '<stub>gp' column of group labels and a 'state'
        column.
    treatment, control : group labels; each is converted with ``str`` before
        being compared against the '<stub>gp' column.
    stub : prefix used to build the group column name ('<stub>gp').
    includeCO : when False and 'Colorado' is among the matched states, the
        count is reduced by one.

    Returns
    -------
    int : number of distinct states matched, or 0 when ``treatment`` is NaN.
    """
    # A value that is not equal to itself is NaN: there is no group to count.
    if treatment != treatment:
        return 0

    groupLabels = data.loc[:, '{}gp'.format(stub)]
    inEitherGroup = (groupLabels == str(treatment)) | (groupLabels == str(control))
    matchedStates = set(data.loc[inEitherGroup, :]['state'])

    count = len(matchedStates)
    if 'Colorado' in matchedStates and not includeCO:
        count -= 1
    return count

#Add group size to dataframe for plotting purposes
def appendGroupSizes(data):
    """Add one '<stub>GroupSize' column per outcome stub to ``data``.

    For every row and every stub in ('bacReg', 'bacVMT', 'bacGas'), the row's
    '<stub>T' and '<stub>C' labels are handed to getNumWithinGroup to count
    the distinct states in those two groups. The 'bacReg' count leaves
    Colorado out (includeCO=False); the other stubs include it.

    Parameters
    ----------
    data : DataFrame holding the '<stub>T', '<stub>C', '<stub>gp' and 'state'
        columns that this function and getNumWithinGroup read.

    Returns
    -------
    The same DataFrame object, with the new columns written in place.
    """
    stubs = ['bacReg', 'bacVMT', 'bacGas']

    rows = []
    for _, row in tqdm(data.iterrows(), total=len(data), desc='Appending group size.'):
        rows.append([
            getNumWithinGroup(
                data,
                row['{}T'.format(stub)],
                row['{}C'.format(stub)],
                stub,
                # Only the bacReg size excludes Colorado. The original computed
                # this "without Colorado" count without passing
                # includeCO=False, so Colorado was never actually removed.
                includeCO=(stub != 'bacReg'),
            )
            for stub in stubs
        ])

    # Give the array an explicit 2-D shape so the column slicing below also
    # works when ``data`` has no rows.
    sizes = np.array(rows).reshape(len(rows), len(stubs))

    for i, stub in enumerate(stubs):
        data.loc[:, '{}GroupSize'.format(stub)] = sizes[:, i]

    return data

#Plot Goodman-Bacon decomposition coefficients
def plotCoefScatter(data, x=6, title=None, show=False, saveFile = 'Plots/Bacon/ImplicitBetas.png', legend = True, limVal = 20, color=None, sizeScale=500000):
    """Scatter the 'impliedBetaVMT' column against the 'impliedBetaGas' column.

    'impliedBetaVMT' goes on the x axis and 'impliedBetaGas' on the y axis,
    matching the axis labels. Marker areas are the 'bacVMTS' column multiplied
    by ``sizeScale``.

    Parameters
    ----------
    data : DataFrame with 'impliedBetaVMT', 'impliedBetaGas', 'bacVMTcgroup'
        and 'bacVMTS' columns.
    x : figure width in inches; the height is x divided by the golden ratio.
    title : axes title, or None for no title.
    show : call plt.show() after drawing.
    saveFile : path the figure is saved to, or None to skip saving.
    legend : when True, colour points by 'bacVMTcgroup' and place the legend
        outside the axes; when False, draw every point in ``color``.
    limVal : symmetric axis limit applied to both axes, or None to keep the
        automatic limits.
    color : marker colour used when ``legend`` is False.
    sizeScale : multiplier applied to 'bacVMTS' to obtain marker sizes.

    Returns
    -------
    DataFrame : the rows that were plotted, with the added 'Number of States'
    and 'Coefficient Type' columns.
    """
    xCol = "impliedBetaVMT"
    yCol = "impliedBetaGas"
    typeCol = 'bacVMT'+'cgroup'
    sizeCol = 'bacVMT'+'S'

    # NaN is not equal to itself, so this keeps only rows with a VMT value.
    coefData = data.loc[data.loc[:, xCol]==data.loc[:, xCol], :]

    # NOTE(review): this swaps 0 for 1 in every column, not only the size
    # column - confirm that is intended.
    coefData = coefData.replace(0, 1)

    # These are taken from the untouched input frame (aligned on the index),
    # so they do not carry the 0 -> 1 replacement above.
    coefData.loc[:, 'Number of States'] = data.loc[:, sizeCol]
    coefData.loc[:, 'Coefficient Type'] = data.loc[:, typeCol]

    goldenRatio = (1+np.sqrt(5))/2
    fig, ax = plt.subplots(1, 1, figsize=(x, x/goldenRatio))

    markerSizes = coefData.loc[:, sizeCol]*sizeScale

    if legend:
        # Sorted labels keep the colour assignment stable; NaN becomes the
        # string 'nan' under astype(str) and is dropped.
        hueOrder = data.loc[:, typeCol].astype(str)
        hueOrder = sorted(set(hueOrder[hueOrder!='nan']))
        sns.scatterplot(x=xCol, y=yCol, s=markerSizes, hue='Coefficient Type', hue_order=hueOrder, data=coefData, ax=ax)
    else:
        sns.scatterplot(x=xCol, y=yCol, s=markerSizes, data=coefData, legend=False, facecolor=color, color=color, ax=ax)

    if limVal is not None:
        ax.set_xlim(-limVal, limVal)
        ax.set_ylim(-limVal, limVal)

    # Zero reference lines, drawn once. axhline/axvline span the whole axes,
    # so the limits do not need to be captured and restored around them.
    ax.axhline(0, color='xkcd:black', zorder=-10)
    ax.axvline(0, color='xkcd:black', zorder=-10)

    ax.set_xlabel('Implicit Elasticity of VMT')
    ax.set_ylabel('Implicit Elasticity of Gas Use')
    if title is not None:
        ax.set_title(title)
    if legend:
        sns.move_legend(ax, loc='upper left', bbox_to_anchor=(1.02, 1))

    if saveFile is not None:
        fig.savefig(saveFile, bbox_inches='tight')

    if show:
        plt.show()
    return coefData


# Script body: load the data, draw the two combined plots, then draw one plot
# per coefficient type.
data = readData()
plotCoefScatter(data, title='Implicit Elasticity Goodman-Bacon Decomposition')
plotCoefScatter(data, title='Implicit Elasticity Goodman-Bacon Decomposition', saveFile='Plots/Bacon/ImplicitBetas-LargeLim.png', limVal=None)

coefData = data.loc[data.loc[:, 'impliedBetaVMT']==data.loc[:, 'impliedBetaVMT'], :]
types = set(data.loc[:, 'bacVMT'+'cgroup'])

# Sorted, NaN-free list of coefficient-type labels (NaN becomes the string
# 'nan' under astype(str) and is filtered out).
hueOrder = data.loc[:, 'bacVMT'+'cgroup'].astype(str)
hueOrder = sorted(set(hueOrder[hueOrder!='nan']))

# Build the palette once; entry k is the colour for hueOrder[k].
palette = sns.color_palette(n_colors=len(hueOrder))

# Iterate the sorted labels rather than the raw set so that the numbered
# output files map to the same coefficient type on every run (set iteration
# order for strings changes between interpreter runs).
for i, groupType in enumerate(hueOrder):
    tempCoefData = data.loc[data.loc[:, 'bacVMT'+'cgroup'].astype(str)==groupType, :]
    plotCoefScatter(tempCoefData, saveFile='Plots/Bacon/ImplicitBetas-{}.png'.format(i), title='Implicit Elasticity Goodman-Bacon Decomposition\n{}'.format(groupType), legend=False, color=palette[i])
